from typing import Any, Dict, List, Optional

from .core import (
    CallbackAuth,
    CpuSpecificationType,
    GpuType,
    LLMInferenceFramework,
    LLMSource,
    ModelEndpointStatus,
    ModelEndpointType,
    Quantization,
    StorageSpecificationType,
)
from .pydantic_types import BaseModel, Field, HttpUrl
from .rest import GetModelEndpointResponse
from .vllm import VLLMEndpointAdditionalArgs


class CreateLLMEndpointRequest(VLLMEndpointAdditionalArgs, BaseModel):
    """
    Request object for creating an LLM endpoint.

    Combines LLM-specific settings with general endpoint fields, in addition to
    whatever fields are inherited from `VLLMEndpointAdditionalArgs`. The fields
    declared here without a default (`name`, `model_name`, `metadata`,
    `min_workers`, `max_workers`, `per_worker`, `labels`) must be supplied by
    the caller.
    """

    # No default: the caller must provide the endpoint name.
    name: str

    # LLM specific fields
    # `model_name` has no default; the remaining LLM fields fall back to the
    # defaults written below (Hugging Face source, vLLM framework, "latest"
    # image tag, a single shard).
    model_name: str
    source: LLMSource = LLMSource.HUGGING_FACE
    inference_framework: LLMInferenceFramework = LLMInferenceFramework.VLLM
    inference_framework_image_tag: str = "latest"
    num_shards: int = 1
    """
    Number of shards to distribute the model onto GPUs.
    """

    # NOTE(review): typed as an optional `Quantization` value rather than a
    # boolean; None is the default.
    quantize: Optional[Quantization] = None
    """
    Whether to quantize the model.
    """

    checkpoint_path: Optional[str] = None
    """
    Path to the checkpoint to load the model from.
    """

    # General endpoint fields
    # `metadata`, `min_workers`, `max_workers`, `per_worker` and `labels` have
    # no default; every other field in this section is optional.
    metadata: Dict[str, Any]  # TODO: JSON type
    post_inference_hooks: Optional[List[str]] = None
    endpoint_type: ModelEndpointType = ModelEndpointType.STREAMING
    # Resource settings all default to None; how None is resolved is decided
    # outside this file.
    cpus: Optional[CpuSpecificationType] = None
    gpus: Optional[int] = None
    memory: Optional[StorageSpecificationType] = None
    gpu_type: Optional[GpuType] = None
    storage: Optional[StorageSpecificationType] = None
    nodes_per_worker: Optional[int] = None
    optimize_costs: Optional[bool] = None
    # NOTE(review): no relationship between min_workers and max_workers is
    # validated in this class — confirm it is checked by the consumer.
    min_workers: int
    max_workers: int
    per_worker: int
    labels: Dict[str, str]
    prewarm: Optional[bool] = None
    high_priority: Optional[bool] = None
    billing_tags: Optional[Dict[str, Any]] = None
    default_callback_url: Optional[HttpUrl] = None
    default_callback_auth: Optional[CallbackAuth] = None
    public_inference: Optional[bool] = True  # LLM endpoints are public by default.
    chat_template_override: Optional[str] = Field(
        default=None,
        description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint",
    )


class CreateLLMEndpointResponse(BaseModel):
    """Response object returned for an LLM endpoint creation request."""

    # Required string with no default; `Field(...)` states the requirement
    # explicitly instead of relying on the bare annotation.
    endpoint_creation_task_id: str = Field(...)


class GetLLMEndpointResponse(BaseModel):
    """
    Response object for retrieving a Model.
    """

    # Only `name`, `source`, `status` and `inference_framework` are declared
    # without a default; every other field defaults to None.
    # The bare string after each field repeats its `Field` description for
    # readers of the source.

    id: Optional[str] = Field(
        default=None,
        description="(For self-hosted users) The autogenerated ID of the model.",
    )
    """(For self-hosted users) The autogenerated ID of the model."""

    name: str = Field(
        description="The name of the model. Use this for making inference requests to the model."
    )
    """The name of the model. Use this for making inference requests to the model."""

    model_name: Optional[str] = Field(
        default=None,
        description="(For self-hosted users) For fine-tuned models, the base model. For base models, this will be the same as `name`.",
    )
    """(For self-hosted users) For fine-tuned models, the base model. For base models, this will be the same as `name`."""

    source: LLMSource = Field(description="The source of the model, e.g. Hugging Face.")
    """The source of the model, e.g. Hugging Face."""

    # NOTE(review): the value list in the string below is not checked against
    # `ModelEndpointStatus` in this file — confirm it is still complete.
    status: ModelEndpointStatus = Field(description="The status of the model.")
    """The status of the model (can be one of "READY", "UPDATE_PENDING", "UPDATE_IN_PROGRESS", "UPDATE_FAILED", "DELETE_IN_PROGRESS")."""

    # NOTE(review): the `Field` description omits the "(For self-hosted users)"
    # prefix that the string below carries — confirm which wording is intended.
    inference_framework: LLMInferenceFramework = Field(
        description="The inference framework used by the model."
    )
    """(For self-hosted users) The inference framework used by the model."""

    inference_framework_tag: Optional[str] = Field(
        default=None,
        description="(For self-hosted users) The Docker image tag used to run the model.",
    )
    """(For self-hosted users) The Docker image tag used to run the model."""

    num_shards: Optional[int] = Field(
        default=None, description="(For self-hosted users) The number of shards."
    )
    """(For self-hosted users) The number of shards."""

    quantize: Optional[Quantization] = Field(
        default=None, description="(For self-hosted users) The quantization method."
    )
    """(For self-hosted users) The quantization method."""

    # Nested general endpoint details, typed as `GetModelEndpointResponse`.
    spec: Optional[GetModelEndpointResponse] = Field(
        default=None, description="(For self-hosted users) Model endpoint details."
    )
    """(For self-hosted users) Model endpoint details."""

    # Unlike the fields above, this one has no trailing string; the `Field`
    # description is its only documentation.
    chat_template_override: Optional[str] = Field(
        default=None,
        description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint",
    )


class ListLLMEndpointsResponse(BaseModel):
    """Response object for listing Models."""

    # No default is given, so the field stays required exactly as it was with
    # an explicit `...` sentinel.
    model_endpoints: List[GetLLMEndpointResponse] = Field(description="The list of models.")
    """A list of Models, represented as `GetLLMEndpointResponse`s."""


class UpdateLLMEndpointRequest(VLLMEndpointAdditionalArgs, BaseModel):
    """
    Request object for updating an LLM endpoint.

    Every field declared here is optional and defaults to None, except
    `force_bundle_recreation`, which defaults to False. Fields inherited from
    `VLLMEndpointAdditionalArgs` are accepted as well.
    """

    # NOTE(review): presumably a field left as None means "keep the current
    # value" — confirm against the code that consumes this request.
    # Unlike CreateLLMEndpointRequest, this class does not itself declare
    # `name`, `inference_framework`, `endpoint_type` or `nodes_per_worker`.

    # LLM specific fields
    model_name: Optional[str] = None
    source: Optional[LLMSource] = None
    inference_framework_image_tag: Optional[str] = None
    num_shards: Optional[int] = None
    """
    Number of shards to distribute the model onto GPUs.
    """

    # NOTE(review): typed as an optional `Quantization` value rather than a
    # boolean; None is the default.
    quantize: Optional[Quantization] = None
    """
    Whether to quantize the model.
    """

    checkpoint_path: Optional[str] = None
    """
    Path to the checkpoint to load the model from.
    """

    # General endpoint fields
    metadata: Optional[Dict[str, Any]] = None
    post_inference_hooks: Optional[List[str]] = None
    cpus: Optional[CpuSpecificationType] = None
    gpus: Optional[int] = None
    memory: Optional[StorageSpecificationType] = None
    gpu_type: Optional[GpuType] = None
    storage: Optional[StorageSpecificationType] = None
    optimize_costs: Optional[bool] = None
    min_workers: Optional[int] = None
    max_workers: Optional[int] = None
    per_worker: Optional[int] = None
    labels: Optional[Dict[str, str]] = None
    prewarm: Optional[bool] = None
    high_priority: Optional[bool] = None
    billing_tags: Optional[Dict[str, Any]] = None
    default_callback_url: Optional[HttpUrl] = None
    default_callback_auth: Optional[CallbackAuth] = None
    # Defaults to None here, whereas the create request defaults it to True.
    public_inference: Optional[bool] = None
    chat_template_override: Optional[str] = Field(
        default=None,
        description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint",
    )

    # The only field in this class whose default is not None.
    force_bundle_recreation: Optional[bool] = False
    """
    Whether to force recreate the underlying bundle.

    If True, the underlying bundle will be recreated. This is useful if there are underlying implementation changes with how bundles are created
    that we would like to pick up for existing endpoints
    """


class UpdateLLMEndpointResponse(BaseModel):
    """Response object returned for an LLM endpoint update request."""

    # Required string with no default; `Field(...)` states the requirement
    # explicitly instead of relying on the bare annotation.
    endpoint_creation_task_id: str = Field(...)


class DeleteLLMEndpointResponse(BaseModel):
    """Response object for deleting a Model."""

    # No default is given, so the flag stays required exactly as it was with
    # an explicit `...` sentinel.
    deleted: bool = Field(description="Whether deletion was successful.")
    """Whether the deletion succeeded."""
